# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Dict, Callable, Optional

import jax

import brainpy.math as bm
from brainpy.connect import TwoEndConnector, All2All, One2One
from brainpy.dnn import linear
from brainpy.dyn import _docs
from brainpy.dyn import synapses
from brainpy.dyn.base import NeuDyn
from brainpy.dynold.synouts import MgBlock, CUBA
from brainpy.initialize import Initializer
from brainpy.types import ArrayType
from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre

__all__ = [
    'Delta',
    'Exponential',
    'DualExponential',
    'Alpha',
    'NMDA',
]


class Delta(TwoEndConn):
    r"""Voltage Jump Synapse Model, or alias of Delta Synapse Model.

    **Model Descriptions**

    .. math::

        I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \mathrm{STP} * \delta(t-t_j-D)

    where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
    :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
    :math:`C` the set of neurons connected to the post-synaptic neuron,
    :math:`D` the transmission delay of chemical synapses,
    and :math:`\mathrm{STP}` the short-term plasticity effect.
    For simplicity, the rise and decay phases of post-synaptic currents are
    omitted in this model.

    **Model Examples**

    >>> import brainpy as bp
    >>> from brainpy import synapses, neurons
    >>> import matplotlib.pyplot as plt
    >>>
    >>> neu1 = neurons.LIF(1)
    >>> neu2 = neurons.LIF(1)
    >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), g_max=5.)
    >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
    >>>
    >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.), ('post.input', 10.)], monitors=['pre.V', 'post.V', 'pre.spike'])
    >>> runner.run(150.)
    >>>
    >>> fig, gs = bp.visualize.get_figure(1, 1, 3, 8)
    >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
    >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
    >>> plt.xlim(40, 150)
    >>> plt.legend()
    >>> plt.show()

    Parameters::

    pre: NeuDyn
      The pre-synaptic neuron group.
    post: NeuDyn
      The post-synaptic neuron group.
    conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
      The synaptic connections.
    comp_method: str
      The connection type used for model speed optimization. It can be
      `sparse` and `dense`. The default is `sparse`.
    delay_step: int, ArrayType, Initializer, Callable
      The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
    g_max: float, ArrayType, Initializer, Callable
      The synaptic strength. Default is 1.
    post_ref_key: str
      Whether the post-synaptic group has refractory period.
    """

    def __init__(
        self,
        pre: NeuDyn,
        post: NeuDyn,
        conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
        output: _SynOut = CUBA(target_var='V'),
        stp: Optional[_SynSTP] = None,
        comp_method: str = 'sparse',
        g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
        delay_step: Union[float, ArrayType, Initializer, Callable] = None,
        post_ref_key: str = None,
        name: str = None,
        mode: bm.Mode = None,
        stop_spike_gradient: bool = False,
    ):
        super().__init__(name=name,
                         pre=pre,
                         post=post,
                         conn=conn,
                         output=output,
                         stp=stp,
                         mode=mode)

        # parameters
        self.stop_spike_gradient = stop_spike_gradient
        self.post_ref_key = post_ref_key
        if post_ref_key:
            self.check_post_attrs(post_ref_key)
        self.comp_method = comp_method

        # connections and weights
        self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

        # register delay
        self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

    def update(self, pre_spike=None):
        # pre-synaptic spikes
        if pre_spike is None:
            pre_spike = self.pre.get_local_delay("spike", self.name)
        pre_spike = bm.as_jax(pre_spike)
        if self.stop_spike_gradient:
            pre_spike = jax.lax.stop_gradient(pre_spike)

        # update sub-components
        if self.stp is not None:
            self.stp.update(pre_spike)

        # synaptic values onto the post
        if isinstance(self.conn, All2All):
            syn_value = bm.asarray(pre_spike, dtype=bm.float_)
            if self.stp is not None:
                syn_value = self.stp(syn_value)
            post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
        elif isinstance(self.conn, One2One):
            syn_value = bm.asarray(pre_spike, dtype=bm.float_)
            if self.stp is not None:
                syn_value = self.stp(syn_value)
            post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
        else:
            if self.comp_method == 'sparse':
                if self.stp is not None:
                    syn_value = self.stp(pre_spike)
                    f = lambda s: bm.sparse.csrmv(
                        self.g_max, self.conn_mask[0], self.conn_mask[1], s,
                        shape=(self.pre.num, self.post.num), transpose=True
                    )
                else:
                    syn_value = pre_spike
                    f = lambda s: bm.event.csrmv(
                        self.g_max, self.conn_mask[0], self.conn_mask[1], s,
                        shape=(self.pre.num, self.post.num), transpose=True
                    )
                if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f)
                post_vs = f(syn_value)
            else:
                syn_value = bm.asarray(pre_spike, dtype=bm.float_)
                if self.stp is not None:
                    syn_value = self.stp(syn_value)
                post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
        if self.post_ref_key:
            post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key))

        # update outputs
        return self.output(post_vs)


class Exponential(TwoEndConn):
    r"""Exponential decay synapse model.

    %s

    **Model Examples**

    - `(Brunel & Hakim, 1999) Fast Global Oscillation <https://brainpy-examples.readthedocs.io/en/latest/oscillation_synchronization/Brunel_Hakim_1999_fast_oscillation.html>`_
    - `(Vreeswijk & Sompolinsky, 1996) E/I balanced network <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Vreeswijk_1996_EI_net.html>`_
    - `(Brette, et, al., 2007) CUBA <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Brette_2007_CUBA.html>`_
    - `(Tian, et al., 2020) E/I Net for fast response <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Tian_2020_EI_net_for_fast_response.html>`_


    >>> import brainpy as bp
    >>> from brainpy import neurons, synapses, synouts
    >>> import matplotlib.pyplot as plt
    >>>
    >>> neu1 = neurons.LIF(1)
    >>> neu2 = neurons.LIF(1)
    >>> syn1 = synapses.Exponential(neu1, neu2, bp.conn.All2All(),
    >>>                             g_max=5., output=synouts.CUBA())
    >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
    >>>
    >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g'])
    >>> runner.run(150.)
    >>>
    >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
    >>> fig.add_subplot(gs[0, 0])
    >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
    >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
    >>> plt.legend()
    >>>
    >>> fig.add_subplot(gs[1, 0])
    >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
    >>> plt.legend()
    >>> plt.show()

    Parameters::

    pre: NeuGroup
      The pre-synaptic neuron group.
    post: NeuGroup
      The post-synaptic neuron group.
    conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
      The synaptic connections.
    comp_method: str
      The connection type used for model speed optimization. It can be
      `sparse` and `dense`. The default is `sparse`.
    delay_step: int, ArrayType, Initializer, Callable
      The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
    tau: float, ArrayType
      The time constant of decay. [ms]
    g_max: float, ArrayType, Initializer, Callable
      The synaptic strength (the maximum conductance). Default is 1.
    name: str
      The name of this synaptic projection.
    method: str
      The numerical integration methods.

    """

    def __init__(
        self,
        pre: NeuDyn,
        post: NeuDyn,
        conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
        output: Optional[_SynOut] = CUBA(),
        stp: Optional[_SynSTP] = None,
        comp_method: str = 'sparse',
        g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
        delay_step: Union[int, ArrayType, Initializer, Callable] = None,
        tau: Union[float, ArrayType] = 8.0,
        method: str = 'exp_auto',

        # other parameters
        name: str = None,
        mode: bm.Mode = None,
        stop_spike_gradient: bool = False,
    ):
        super().__init__(pre=pre,
                         post=post,
                         conn=conn,
                         output=output,
                         stp=stp,
                         name=name,
                         mode=mode)
        # parameters
        self.stop_spike_gradient = stop_spike_gradient

        # synapse dynamics
        self.syn = synapses.Expon(post.varshape, tau=tau, method=method)

        # Projection
        if isinstance(conn, All2All):
            self.comm = linear.AllToAll(pre.num, post.num, g_max)
        elif isinstance(conn, One2One):
            assert post.num == pre.num
            self.comm = linear.OneToOne(pre.num, g_max)
        else:
            if comp_method == 'dense':
                self.comm = linear.MaskedLinear(conn, g_max)
            elif comp_method == 'sparse':
                if self.stp is None:
                    self.comm = linear.EventCSRLinear(conn, g_max)
                else:
                    self.comm = linear.CSRLinear(conn, g_max)
            else:
                raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".')

        # delay
        self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

    @property
    def g(self):
        return self.syn.g

    @g.setter
    def g(self, value):
        self.syn.g = value

    def update(self, pre_spike=None):
        # delays
        if pre_spike is None:
            pre_spike = self.pre.get_local_delay("spike", self.name)
        pre_spike = bm.as_jax(pre_spike)
        if self.stop_spike_gradient:
            pre_spike = jax.lax.stop_gradient(pre_spike)

        # update sub-components
        self.output.update()
        if self.stp is not None:
            self.stp.update(pre_spike)
            pre_spike = self.stp(pre_spike)

        # post values
        g = self.syn(self.comm(pre_spike))

        # output
        return self.output(g)


Exponential.__doc__ = Exponential.__doc__ % (_docs.exp_syn_doc,)


class DualExponential(_TwoEndConnAlignPre):
    r"""Dual exponential synapse model.

    %s

    **Model Examples**

    >>> import brainpy as bp
    >>> from brainpy import neurons, synapses, synouts
    >>> import matplotlib.pyplot as plt
    >>>
    >>> neu1 = neurons.LIF(1)
    >>> neu2 = neurons.LIF(1)
    >>> syn1 = synapses.DualExponential(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA())
    >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
    >>>
    >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
    >>> runner.run(150.)
    >>>
    >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
    >>> fig.add_subplot(gs[0, 0])
    >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
    >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
    >>> plt.legend()
    >>>
    >>> fig.add_subplot(gs[1, 0])
    >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
    >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
    >>> plt.legend()
    >>> plt.show()

    Parameters::

    pre: NeuDyn
      The pre-synaptic neuron group.
    post: NeuDyn
      The post-synaptic neuron group.
    conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
      The synaptic connections.
    comp_method: str
      The connection type used for model speed optimization. It can be
      `sparse` and `dense`. The default is `sparse`.
    delay_step: int, ArrayType, Initializer, Callable
      The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
    tau_decay: float, ArrayArray, ndarray
      The time constant of the synaptic decay phase. [ms]
    tau_rise: float, ArrayArray, ndarray
      The time constant of the synaptic rise phase. [ms]
    g_max: float, ArrayType, Initializer, Callable
      The synaptic strength (the maximum conductance). Default is 1.
    name: str
      The name of this synaptic projection.
    method: str
      The numerical integration methods.

    """

    def __init__(
        self,
        pre: NeuDyn,
        post: NeuDyn,
        conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
        stp: Optional[_SynSTP] = None,
        output: _SynOut = None,  # CUBA(),
        comp_method: str = 'dense',
        g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
        tau_decay: Union[float, ArrayType] = 10.0,
        tau_rise: Union[float, ArrayType] = 1.,
        delay_step: Union[int, ArrayType, Initializer, Callable] = None,
        A: Optional[Union[float, ArrayType, Callable]] = None,
        method: str = 'exp_auto',

        # other parameters
        name: str = None,
        mode: bm.Mode = None,
        stop_spike_gradient: bool = False,
    ):

        # parameters
        self.stop_spike_gradient = stop_spike_gradient
        self.comp_method = comp_method
        self.tau_rise = tau_rise
        self.tau_decay = tau_decay
        if bm.size(self.tau_rise) != 1:
            raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. '
                             f'But we got {self.tau_rise}')
        if bm.size(self.tau_decay) != 1:
            raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
                             f'But we got {self.tau_decay}')

        syn = synapses.DualExpon(pre.size,
                                 pre.keep_size,
                                 A=A,
                                 mode=mode,
                                 tau_decay=tau_decay,
                                 tau_rise=tau_rise,
                                 method=method, )

        super().__init__(pre=pre,
                         post=post,
                         syn=syn,
                         conn=conn,
                         output=output,
                         stp=stp,
                         comp_method=comp_method,
                         g_max=g_max,
                         delay_step=delay_step,
                         name=name,
                         mode=mode)

        self.check_post_attrs('input')
        # copy the references
        self.g = syn.g
        self.h = syn.h

    def update(self, pre_spike=None):
        return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)


DualExponential.__doc__ = DualExponential.__doc__ % (_docs.dual_exp_syn_doc,)


class Alpha(_TwoEndConnAlignPre):
    r"""Alpha synapse model.

    %s

    **Model Examples**

    >>> import brainpy as bp
    >>> from brainpy import neurons, synapses, synouts
    >>> import matplotlib.pyplot as plt
    >>>
    >>> neu1 = neurons.LIF(1)
    >>> neu2 = neurons.LIF(1)
    >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA())
    >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
    >>>
    >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
    >>> runner.run(150.)
    >>>
    >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
    >>> fig.add_subplot(gs[0, 0])
    >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
    >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
    >>> plt.legend()
    >>> fig.add_subplot(gs[1, 0])
    >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
    >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
    >>> plt.legend()
    >>> plt.show()

    Parameters::

    pre: NeuDyn
      The pre-synaptic neuron group.
    post: NeuDyn
      The post-synaptic neuron group.
    conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
      The synaptic connections.
    comp_method: str
      The connection type used for model speed optimization. It can be
      `sparse` and `dense`. The default is `sparse`.
    delay_step: int, ArrayType, Initializer, Callable
      The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
    tau_decay: float, ArrayType
      The time constant of the synaptic decay phase. [ms]
    g_max: float, ArrayType, Initializer, Callable
      The synaptic strength (the maximum conductance). Default is 1.
    name: str
      The name of this synaptic projection.
    method: str
      The numerical integration methods.

    """

    def __init__(
        self,
        pre: NeuDyn,
        post: NeuDyn,
        conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
        output: _SynOut = None,  # CUBA(),
        stp: Optional[_SynSTP] = None,
        comp_method: str = 'dense',
        g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
        delay_step: Union[int, ArrayType, Initializer, Callable] = None,
        tau_decay: Union[float, ArrayType] = 10.0,
        method: str = 'exp_auto',

        # other parameters
        name: str = None,
        mode: bm.Mode = None,
        stop_spike_gradient: bool = False,
    ):
        # parameters
        self.stop_spike_gradient = stop_spike_gradient
        self.comp_method = comp_method
        self.tau_decay = tau_decay
        if bm.size(self.tau_decay) != 1:
            raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
                             f'But we got {self.tau_decay}')

        syn = synapses.Alpha(pre.size,
                             pre.keep_size,
                             mode=mode,
                             tau_decay=tau_decay,
                             method=method)

        super().__init__(pre=pre,
                         post=post,
                         syn=syn,
                         conn=conn,
                         comp_method=comp_method,
                         delay_step=delay_step,
                         g_max=g_max,
                         output=output,
                         stp=stp,
                         name=name,
                         mode=mode, )

        self.check_post_attrs('input')
        # copy the references
        self.g = syn.g
        self.h = syn.h

    def update(self, pre_spike=None):
        return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)


Alpha.__doc__ = Alpha.__doc__ % (_docs.alpha_syn_doc,)


class NMDA(_TwoEndConnAlignPre):
    r"""NMDA synapse model.

    **Model Descriptions**

    The NMDA receptor is a glutamate receptor and ion channel found in neurons.
    The NMDA receptor is one of three types of ionotropic glutamate receptors,
    the other two being AMPA and kainate receptors.

    The NMDA receptor mediated conductance depends on the postsynaptic voltage.
    The voltage dependence is due to the blocking of the pore of the NMDA receptor
    from the outside by a positively charged magnesium ion. The channel is
    nearly completely blocked at resting potential, but the magnesium block is
    relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
    that are not blocked by magnesium can be fitted to

    .. math::

        g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V}
        \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}

    Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
    usually 1 mM. Thus, the channel acts as a
    "coincidence detector" and only once both of these conditions are met, the
    channel opens and it allows positively charged ions (cations) to flow through
    the cell membrane [2]_.

    If we make the approximation that the magnesium block changes
    instantaneously with voltage and is independent of the gating of the channel,
    the net NMDA receptor-mediated synaptic current is given by

    .. math::

        I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}

    where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
    reversal potential.

    Simultaneously, the kinetics of synaptic state :math:`g` is given by

    .. math::

        & g_\mathrm{NMDA} (t) = g_{max} g \\
        & \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\
        & \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k})

    where the decay time of NMDA currents is usually taken to be
    :math:`\tau_{decay}` =100 ms, :math:`a= 0.5 ms^{-1}`, and :math:`\tau_{rise}` =2 ms.

    The NMDA receptor has been thought to be very important for controlling
    synaptic plasticity and mediating learning and memory functions [3]_.


    **Model Examples**

    - `(Wang, 2002) Decision making spiking model <https://brainpy-examples.readthedocs.io/en/latest/decision_making/Wang_2002_decision_making_spiking.html>`_


    >>> import brainpy as bp
    >>> from brainpy import synapses, neurons
    >>> import matplotlib.pyplot as plt
    >>>
    >>> neu1 = neurons.HH(1)
    >>> neu2 = neurons.HH(1)
    >>> syn1 = synapses.NMDA(neu1, neu2, bp.connect.All2All())
    >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
    >>>
    >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x'])
    >>> runner.run(150.)
    >>>
    >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
    >>> fig.add_subplot(gs[0, 0])
    >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
    >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
    >>> plt.legend()
    >>>
    >>> fig.add_subplot(gs[1, 0])
    >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
    >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x')
    >>> plt.legend()
    >>> plt.show()

    Parameters::

    pre: NeuDyn
      The pre-synaptic neuron group.
    post: NeuDyn
      The post-synaptic neuron group.
    conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
      The synaptic connections.
    comp_method: str
      The connection type used for model speed optimization. It can be
      `sparse` and `dense`. The default is `dense`.
    delay_step: int, ArrayType, Initializer, Callable
      The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
    g_max: float, ArrayType, Initializer, Callable
      The synaptic strength (the maximum conductance). Default is 1.
    tau_decay: float, ArrayType
      The time constant of the synaptic decay phase. Default 100 [ms]
    tau_rise: float, ArrayType
      The time constant of the synaptic rise phase. Default 2 [ms]
    a: float, ArrayType
      Default 0.5 ms^-1.
    name: str
      The name of this synaptic projection.
    method: str
      The numerical integration methods.

    References::

    .. [1] Brunel N, Wang X J. Effects of neuromodulation in a
           cortical network model of object working memory dominated
           by recurrent inhibition[J].
           Journal of computational neuroscience, 2001, 11(1): 63-85.
    .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
           Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
           Nature 438, no. 7065 (2005): 185-192.
    .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
           England journal of medicine, 361(3), p.302.
    .. [4] https://en.wikipedia.org/wiki/NMDA_receptor

    """

    def __init__(
        self,
        pre: NeuDyn,
        post: NeuDyn,
        conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
        output: _SynOut = MgBlock(E=0., alpha=0.062, beta=3.57, cc_Mg=1.2),
        stp: Optional[_SynSTP] = None,
        comp_method: str = 'dense',
        g_max: Union[float, ArrayType, Initializer, Callable] = 0.15,
        delay_step: Union[int, ArrayType, Initializer, Callable] = None,
        tau_decay: Union[float, ArrayType] = 100.,
        a: Union[float, ArrayType] = 0.5,
        tau_rise: Union[float, ArrayType] = 2.,
        method: str = 'exp_auto',
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
        stop_spike_gradient: bool = False,
    ):
        # parameters
        self.tau_decay = tau_decay
        self.tau_rise = tau_rise
        self.a = a
        if bm.size(a) != 1:
            raise ValueError(f'"a" must be a scalar or a tensor with size of 1. But we got {a}')
        if bm.size(tau_decay) != 1:
            raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. But we got {tau_decay}')
        if bm.size(tau_rise) != 1:
            raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. But we got {tau_rise}')
        self.comp_method = comp_method
        self.stop_spike_gradient = stop_spike_gradient

        syn = synapses.NMDA(pre.size,
                            pre.keep_size,
                            mode=mode,
                            a=a,
                            tau_decay=tau_decay,
                            tau_rise=tau_rise,
                            method=method, )

        super().__init__(pre=pre,
                         post=post,
                         syn=syn,
                         conn=conn,
                         output=output,
                         stp=stp,
                         comp_method=comp_method,
                         g_max=g_max,
                         delay_step=delay_step,
                         name=name,
                         mode=mode)

        # copy the references
        self.g = syn.g
        self.x = syn.x

    def update(self, pre_spike=None):
        return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
